I want to approach the question of inclusive machine learning systems by analyzing a fun, open and global data set collected by Quick, Draw! With this data set, I want to answer the following questions:
The corresponding blog post is accessible on Medium - link
In the notebook below, I run the analysis for each of the different categories (Cake, Cookie, Ice Cream, Sandwich)
Note that due to the volume of data sets used from Quick Draw, there is no data stored in this repo.
Results and discussion can be found in the post here
The data set and different preprocessed formats are explained in Quick, Draw! github
The raw data is available as ndjson files separated by category, in the following format:
| Key | Type | Description |
|---|---|---|
| key_id | 64-bit unsigned integer | A unique identifier across all drawings. |
| word | string | Category the player was prompted to draw. |
| recognized | boolean | Whether the word was recognized by the game. |
| timestamp | datetime | When the drawing was created. |
| countrycode | string | A two letter country code (ISO 3166-1 alpha-2) of where the player was located. |
| drawing | string | A JSON array representing the vector drawing |
The format of the drawing array is as follows:
[
[ // First stroke
[x0, x1, x2, x3, ...],
[y0, y1, y2, y3, ...],
[t0, t1, t2, t3, ...]
],
[ // Second stroke
[x0, x1, x2, x3, ...],
[y0, y1, y2, y3, ...],
[t0, t1, t2, t3, ...]
],
... // Additional strokes
]
The Simplified Drawing format used here has simplified vectors, no more timing information in the drawing array, and the data has been positioned and scaled into a 256x256 region. The simplification process was:
# Import libraries necessary for this project
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.offline as py
import plotly.graph_objs as go
%matplotlib inline
py.init_notebook_mode()
%%time
# Load Simplified ndjson for the different categories. Datasets are too big to load them all at once (Notebook crashed).
# I have downloaded the data sets locally and loading them here.
category = "ice_cream" # "cake", "cookie", "sandwich"
filepath = "full_simplified_{}.ndjson".format(category)
df = pd.read_json(filepath, lines=True)
df.info()
df.head(10)
We will create several new features on top of existing data to support our investigations
# Sanity-check the timestamp column: which years and months appear?
df.timestamp.dt.year.unique()
df.groupby(df.timestamp.dt.month).count()
All drawings are timestamped January or March 2017 (for all 4 categories)
Let's add more geographical information to be able to answer our first question around where do the doodles come from (add iso alpha 3 codes and region).
# Quick Draw stores an ISO 3166-1 alpha-2 countrycode; plotly needs real
# country names and alpha-3 codes, so load a lookup table with both + region.

# How many distinct country codes occur in the drawings?
df['countrycode'].nunique()

# NB: keep_default_na=False, otherwise Namibia's alpha-2 code 'NA' would be
# parsed back as NaN by the csv reader.
iso_columns = ['country', 'countrycode', 'countrycode3', 'region']
ISO_countries = pd.read_csv('country_ISO.csv', sep=',', names=iso_columns,
                            header=0, keep_default_na=False)
ISO_countries.head(5)

# How many lookup-table codes actually appear in the Quick Draw data?
ISO_countries['countrycode'].isin(df['countrycode']).value_counts()
Some country codes are not matching.
# Left-join the ISO lookup onto the drawings: every Quick Draw row is kept,
# and rows whose countrycode has no match get NaN in the ISO columns.
df_c = df.merge(ISO_countries, on='countrycode', how='left')

# Which country codes failed to match (country is NaN after the join)?
df_c.loc[df_c['country'].isnull(), 'countrycode'].value_counts()
Looking at https://en.wikipedia.org/wiki/ISO_3166-1_alpha-2 or https://www.iso.org/obp/ui/#search we can see that some country codes like ZZ (user-assigned code) or BU or AN have now been changed and no longer in use
df_c.head(5)

# Drop drawings whose countrycode could not be resolved to a country.
df_c = df_c.dropna(subset=['country'])
df_c.shape

# Per-country drawing counts and their share of the whole data set.
geo_dist_count = df_c['countrycode3'].value_counts()
geo_dist_perc = geo_dist_count / len(df_c) * 100

# Number of distinct, identified countries.
df_c['countrycode3'].nunique()

# Countries contributing at least 1000 drawings for this category.
geo_dist_count[geo_dist_count >= 1000]
# Choropleth world map: percentage of drawings contributed by each country.
map_data = {
    'type': 'choropleth',
    'locations': geo_dist_perc.index,
    'z': geo_dist_perc,
    'colorbar': {'ticksuffix': '%'},
}
map_layout = {
    'title': 'Geographical distribution - Category: {}'.format(category),
    'geo': {
        'showframe': False,
        'projection': {'type': 'equirectangular'},
    },
}
choromap = go.Figure(data=[map_data], layout=map_layout)
py.iplot(choromap)
# Countplot of drawings per country (countries with >= 1000 drawings only),
# each bar annotated with its percentage of ALL drawings in df_c.
f, ax = plt.subplots(1, 1, figsize=(16, 4))
total = float(len(df_c))
keep = geo_dist_count[geo_dist_count >= 1000].index
df_sub = df_c[df_c['countrycode3'].isin(keep)]
g = sns.countplot(df_sub['countrycode3'],
                  order=df_sub['countrycode3'].value_counts().index,
                  palette='Set3')
g.set_title("Number and percentage of drawings per country for {} (countries with >= 1000 drawings)".format(category))
# Write each country's percentage (relative to all of df_c) above its bar.
for bar in ax.patches:
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width() / 2.,
            height + 3,
            '{:1.2f}%'.format(100 * height / total),
            ha="center")
plt.show()
# Restrict further analysis to "main" countries: those with at least 1000
# drawings in this category.
main_countries = geo_dist_count[geo_dist_count >= 1000].index
df_q2 = df_c[df_c['countrycode3'].isin(main_countries)].copy()
len(main_countries)
df_q2.info()
# Feature engineering on the main-country subset:
#   stroke_number - number of strokes in the drawing
#   recognized    - boolean mapped to 0/1 so it can be averaged later
#   total_np      - total number of (x, y) points across all strokes
df_q2['stroke_number'] = df_q2['drawing'].str.len()
df_q2['recognized'] = df_q2['recognized'].map({True: 1, False: 0})
# Each stroke is [[x...], [y...]]; counting the x list is enough since every
# x has a matching y. Series.apply replaces the original per-row .loc loop
# with list accumulation, which was far slower on a frame this large and
# produced identical values.
df_q2['total_np'] = df_q2['drawing'].apply(
    lambda drawing: sum(len(stroke[0]) for stroke in drawing))
# Two more features derived from the first stroke of each drawing:
#   fs_direction - angle (radians, normalised to [0, 2*pi)) of the first
#                  stroke's first segment
#   fs_dp        - number of points in the first stroke
# np.arctan2 handles all four quadrants and dx == 0 exactly, replacing the
# original manual quadrant bookkeeping and its dx = 0.000001 workaround.
fs_direction = []
fs_dp = []
for i in df_q2.index:
    first_stroke = df_q2.loc[i, 'drawing'][0]
    fs_dp.append(len(first_stroke[0]))
    dx = first_stroke[0][1] - first_stroke[0][0]
    dy = first_stroke[1][1] - first_stroke[1][0]
    # arctan2 returns values in (-pi, pi]; the modulo maps them onto [0, 2*pi).
    fs_direction.append(np.arctan2(dy, dx) % (2 * np.pi))
df_q2['fs_direction'] = fs_direction
df_q2['fs_dp'] = fs_dp
df_q2.head(10)
With these additional features for our main countries, we can now look for any patterns linked to geography or culture.
df_q2.groupby('countrycode')[['recognized', 'stroke_number']].describe(include='all')
df_q2.columns
# Pair grid of the engineered features, coloured by country, for a quick
# visual summary of the data.
feature_cols = ['recognized', 'countrycode3', 'stroke_number',
                'total_np', 'fs_direction', 'fs_dp']
grid = sns.pairplot(df_q2[feature_cols], hue="countrycode3", diag_kind="hist")
for axis in grid.axes.flat:
    plt.setp(axis.get_xticklabels(), rotation=45)
# Distribution of the number of strokes per drawing, one KDE per country.
# (The enumerate index was unused, and a commented-out alternative line was
# dead code; both removed.)
sns.set_style("whitegrid", {'axes.grid': False})
plt.figure(figsize=(16, 8)).suptitle("Distribution of strokes number for category: {}".format(category), fontsize=16)
for ct in main_countries:
    x = df_q2.loc[df_q2.countrycode3 == ct, 'stroke_number']
    sns.distplot(x, hist=False, rug=False, label=ct)
# Recognition rate per country: mean of the 0/1 recognized flag, as a percent.
with plt.style.context('seaborn-whitegrid'):
    plt.figure(figsize=(15, 6)).suptitle("Percentage of recognition for category: {}".format(category), fontsize=16)
    plt.ylabel('Percentage')
    recog_rate = df_q2.groupby('countrycode3')['recognized'].mean() * 100
    recog_rate.sort_values().plot.bar()
# Overlay 150 recognized drawings per main country on a 5x5 axes grid:
# a visual summary of what the ML system accepts as this category.
fig, ax = plt.subplots(nrows=5, ncols=5, figsize=(20, 15), sharex='col', sharey='row')
np.vectorize(lambda a: a.axis('off'))(ax)
for j, country in enumerate(main_countries):
    # divmod replaces the original running row counter i = int((j+1)/5),
    # which computed the same row index j // 5 in a roundabout way.
    row, col = divmod(j, 5)
    ax[row, col].set_title(country)  # set once, not once per stroke
    d_recog = df_q2[(df_q2.countrycode3 == country) & (df_q2.recognized == 1)]
    # Guard against countries with fewer than 150 recognized drawings,
    # which would make .sample raise.
    drawings = d_recog.drawing.sample(n=min(150, len(d_recog)))
    for drawing in drawings:
        for strokes in drawing:
            # y is negated because drawing coordinates grow downwards.
            ax[row, col].plot(np.array(strokes[0]), -np.array(strokes[1]),
                              color=(0.6, 0.6, 0.6), alpha=0.5, linewidth=0.5)
# How many unrecognized drawings do we have per country?
df_q2.loc[df_q2.recognized == 0, 'countrycode'].value_counts()
# Display a sample of non-recognized drawings for a country.
def draw_unrecognized(country):
    """Plot up to 40 non-recognized drawings for `country` on an 8x5 grid.

    Reads the module-level df_q2 dataframe; `country` is an alpha-2 code as
    stored in the `countrycode` column.
    """
    d_tmp = df_q2[(df_q2.countrycode == country) & (df_q2.recognized == 0)]
    # Sample at most 40 drawings (fewer if the country has fewer).
    sample = min(len(d_tmp.index), 40)
    drawings = d_tmp.drawing.sample(n=sample)
    fig, ax = plt.subplots(nrows=8, ncols=5, figsize=(15, 15), sharex='col', sharey='row')
    for i in range(8):
        for j in range(5):
            ax[i, j].axis('off')
            k = i * 5 + j
            # Explicit bounds check replaces the original bare `except:`,
            # which silently swallowed ALL errors, not just the expected
            # out-of-range IndexError when sample < 40.
            if k >= len(drawings):
                continue  # grid cell beyond the sampled drawings stays blank
            for strokes in drawings.iloc[k]:
                # Each stroke holds X coordinates at index 0 and Y at index 1;
                # Y is negated because drawing coordinates grow downwards.
                ax[i, j].plot(np.array(strokes[0]), -np.array(strokes[1]), color='black')

draw_unrecognized('HU')